package com.chenzhiling.study.remote

import com.chenzhiling.study.Io.{FileToArrayByte, ScalaIo}
import com.chenzhiling.study.util.PropertiesUtil.getString
import com.sshtools.net.SocketTransport
import com.sshtools.sftp.{SftpClient, SftpFile}
import com.sshtools.ssh.{PasswordAuthentication, SshAuthentication, SshConnector}
import com.sshtools.ssh2.Ssh2Client
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.Row
import org.apache.spark.{Partition, SparkContext, TaskContext}

import java.io.InputStream


/**
 * @Author: CHEN ZHI LING
 * @Date: 2021/8/9
 * @Description: spark自定义Rdd,通过将j2ssh读取远程文件,加载成RDD
 * 部分参考 https://blog.csdn.net/wang972779876/article/details/117567248
 */
// Partition descriptor for OnlineRdd. `index` satisfies Spark's Partition
// contract; `path` carries the remote directory this partition is meant to read
// (note: compute() currently reads `remotePath` directly rather than this field).
case class IntPartition(index: Int,path:String) extends Partition
/**
 * Custom Spark RDD that loads every file under a remote directory over
 * SSH/SFTP (sshtools/j2ssh) and exposes the file contents as [[Row]]s.
 *
 * Row layout (see [[filledSchema]]): (fileName, absolutePath, suffix, size,
 * nested Row of byte-array chunks).
 *
 * @param sc         Spark context used to construct the RDD
 * @param ip         remote host address
 * @param port       remote SSH port
 * @param username   SSH login user
 * @param password   SSH login password
 * @param remotePath directory on the remote host whose files are loaded
 */
class OnlineRdd (sc: SparkContext,
                 ip: String,
                 port: Int,
                 username: String,
                 password: String,
                 remotePath:String)
  extends RDD[Row](sc, Nil) with Serializable {

  // Number of chunks each file's byte array is split into when building rows.
  private val SPLIT_NUMBER: Int = getString("file.split.number").toInt

  /**
   * Connects to the remote host, reads all files under `remotePath` and
   * materialises them as rows.
   *
   * The SFTP session and SSH connection are always torn down in a `finally`
   * block, even when listing or reading a file throws.
   *
   * @throws RuntimeException when password authentication fails
   */
  override def compute(split: Partition, context: TaskContext): Iterator[Row] = {
    // Establish the SSH transport and connect as `username`.
    val connector: SshConnector = SshConnector.createInstance()
    val transport = new SocketTransport(ip, port)
    val ssh: Ssh2Client = connector.connect(transport, username)
    // Password authentication.
    val authentication = new PasswordAuthentication()
    authentication.setPassword(password)
    if (ssh.authenticate(authentication) != SshAuthentication.COMPLETE) {
      // Release the half-open connection before signalling failure
      // (the original leaked it here).
      ssh.disconnect()
      throw new RuntimeException("连接服务器异常")
    }
    val sftpClient = new SftpClient(ssh)
    try {
      // List every remote file, then expand each file into its rows.
      // Array.flatMap forces the per-file iterators, so all remote reads
      // complete before the finally block closes the connection.
      val files: Array[SftpFile] = ScalaIo.listRemoteFiles(remotePath, client = sftpClient)
      val rows: Array[Row] = files.flatMap((file: SftpFile) => sftpFileToRow(file, sftpClient))
      rows.toIterator
    } finally {
      // BUG FIX: the original called `sftpClient.isClosed`, which only *checks*
      // the channel state and never closes it, leaking the SFTP session.
      sftpClient.quit()
      ssh.disconnect()
    }
  }

  /** Single-partition RDD: one task reads everything under `remotePath`. */
  override protected def getPartitions: Array[Partition] =
    Array[Partition](IntPartition(0, path = remotePath))

  /**
   * Converts one remote file into rows, one row per byte chunk.
   *
   * NOTE(review): the InputStream is handed to `splitRemoteStream`, which is
   * presumed to consume and close it — confirm, otherwise streams leak here.
   *
   * @param file       remote file handle
   * @param sftpClient SFTP client used to open the file's stream
   * @return iterator of rows, one per chunk produced by `splitRemoteStream`
   */
  private def sftpFileToRow(file: SftpFile,sftpClient:SftpClient): Iterator[Row] = {
    val path: String = file.getAbsolutePath
    // Open the remote file as a stream.
    val stream: InputStream = sftpClient.getInputStream(path)
    val fileName: String = file.getFilename
    // File size: UnsignedInteger64 -> Long.
    val size: Long = file.getAttributes.getSize.longValue()
    // Split the stream into byte-array chunks.
    val chunks: Iterator[Array[Byte]] = FileToArrayByte.splitRemoteStream(stream, size)
    val suffix: String = ScalaIo.getFileSuffix(fileName)
    chunks.map((array: Array[Byte]) => filledSchema(fileName, path, suffix, size, array))
  }

  /**
   * Builds one output row: (fileName, path, suffix, size, content), where
   * content is a nested Row holding the byte array split into SPLIT_NUMBER parts.
   *
   * @param fileName file name
   * @param path     absolute remote path
   * @param suffix   file extension
   * @param size     file size in bytes
   * @param array    one chunk of the file's content
   */
  private def filledSchema(fileName:String,path:String,suffix:String,size:Long,array:Array[Byte]): Row ={
    val parts: List[Array[Byte]] = ScalaIo.splitArrayByNumber(array, SPLIT_NUMBER)
    Row(fileName, path, suffix, size, Row.fromSeq(parts))
  }
}
